import bpy, bmesh
from mathutils import Vector, Color
import io
import math
import array
from s4studio.animation.blender import find_bone
from s4studio.blender import swizzle_uv, invalid_face, set_context, equals_float_array, blend_index_map, \
    approximate_vector, equals_vector, Sims4StudioException, apply_all_modifiers
from s4studio.cas.catalog import CASPart, BlendData
from s4studio.cas.geometry import BlendGeometry, BodyGeometry
from s4studio.core import ResourceKey
from s4studio.data.package import Package
from s4studio.helpers import FNV32, first
from s4studio.material import PackedPreset, Preset
from s4studio.material.blender import MaterialLoader
from s4studio.model import VisualProxy
from s4studio.model.geometry import Vertex

def save_geom(geom, mesh_object, body_type):
    """Export the Blender mesh ``mesh_object`` into the BodyGeometry ``geom``.

    ``mesh_object`` is modified in place before export: modifiers are applied, hidden geometry is
    revealed, every face is triangulated, location/rotation/scale are applied, and the object is
    rotated -90 degrees about X (the inverse of the +90 degree X rotation ``load_geom`` applies).

    Blender stores UVs per loop while the exported vertex list stores them per vertex. Each Blender
    vertex is therefore split into one exported vertex per distinct UV set. When custom normals
    are enabled, it is also split per distinct normal; they are currently disabled.

    :param geom: BodyGeometry to fill. Its indices, vertices, body type, uv_stitches and bones are
        assigned. Hashes of vertex-group names not already in ``geom.bones`` are appended to it.
    :param mesh_object: Blender mesh object to export.
    :param body_type: value forwarded to ``geom.set_body_type``.
    :raises Sims4StudioException: when ``calc_tangents`` raises RuntimeError (reported to the
        user as a missing UV map).
    """
    apply_all_modifiers(mesh_object)
    assert isinstance(geom, BodyGeometry)
    # geom.bones is extended in place, so bone hashes already present keep their indices.
    vertex_groups = geom.bones
    vertex_group_map = {}
    set_context('EDIT', mesh_object)
    bpy.ops.mesh.reveal()
    bpy.ops.mesh.select_all(action='SELECT')
    bpy.ops.mesh.quads_convert_to_tris()
    bpy.ops.mesh.select_all(action='DESELECT')
    set_context('OBJECT', mesh_object)
    bpy.ops.object.transform_apply(location=True, rotation=True, scale=True)
    # Undo the +90 degree X rotation applied by load_geom on import.
    bpy.ops.transform.rotate(value=math.pi / 2.0, axis=(-1, 0, 0))
    bpy.ops.object.transform_apply(rotation=True)
    mesh_data = mesh_object.data
    try:
        mesh_data.calc_tangents()
    except RuntimeError as err:
        raise Sims4StudioException('Mesh Import Error',
                                   'One or more of your meshes requires a UV map but does not have one.') from err

    # Map each vertex group name to the index of its FNV32 hash in geom.bones.
    for vertex_group in mesh_object.vertex_groups:
        bone_hash = FNV32.hash(vertex_group.name)
        if bone_hash not in vertex_groups:
            vertex_groups.append(bone_hash)
        vertex_group_map[vertex_group.name] = vertex_groups.index(bone_hash)

    # Custom split normals are disabled; plain vertex normals are exported instead.
    # The version test compares (major, minor) as a tuple so releases like 3.0 are not rejected.
    use_custom_normals = False
    compare_normals = use_custom_normals and tuple(bpy.app.version[:2]) >= (2, 74)

    uv_layers = mesh_data.uv_layers
    vertex_colors = mesh_data.vertex_colors

    geom_vertices = []
    indices = []

    # Blender vertex index -> list of (exported vertex, its index in geom_vertices), one entry per
    # distinct UV set (and normal, when compare_normals is set) seen on that Blender vertex.
    vertex_map = {}

    for loop_index, loop in enumerate(mesh_data.loops):
        # The mesh was triangulated above, so every three consecutive loops form one face.
        if loop_index % 3 == 0:
            indices.append([])
        face_index = loop_index // 3
        candidates = vertex_map.setdefault(loop.vertex_index, [])

        # UVs for every layer, with V flipped.
        loop_uv = []
        for uv_layer in uv_layers:
            uv = uv_layer.data[loop_index].uv
            loop_uv.append([uv[0], 1 - uv[1]])

        tangent = loop.tangent
        loop_tangent = [tangent.x, tangent.y, tangent.z]

        if compare_normals:
            loop_normal = [loop.normal.x, loop.normal.y, loop.normal.z]
        else:
            vertex_normal = mesh_data.vertices[loop.vertex_index].normal
            loop_normal = [vertex_normal.x, vertex_normal.y, vertex_normal.z]

        # Reuse an already split vertex when its UVs (and normal, if compared) match this loop.
        match_index = None
        for existing, existing_index in candidates:
            assert isinstance(existing, Vertex)
            uv_equal = all(equals_float_array(loop_uv[i], existing.uv[i], 4) for i in range(len(existing.uv)))
            normal_equal = equals_vector(loop_normal, existing.normal) if compare_normals else True
            if uv_equal and normal_equal:
                match_index = existing_index
                break

        if match_index is None:
            vertex = mesh_data.vertices[loop.vertex_index]
            geom_vertex = Vertex()
            geom_vertex.position = [vertex.co.x, vertex.co.y, vertex.co.z]
            geom_vertex.normal = loop_normal

            # Four blend slots. Only the first four groups in Blender's group order are exported,
            # not the four heaviest. Unused index slots keep the value 1 (NOTE(review): confirm intent).
            blend_index = [1] * 4
            weight_list = [0, 0, 0, 0]
            for i in range(min(len(vertex.groups), 4)):
                group = vertex.groups[i]
                weight_list[blend_index_map[i]] = min(int(group.weight * 255), 255)
                blend_index[i] = vertex_group_map[mesh_object.vertex_groups[int(group.group)].name]
            geom_vertex.blend_indices = blend_index
            geom_vertex.blend_weights = weight_list

            # Colours are scaled to 0-255 with an extra trailing 0 component.
            if len(vertex_colors):
                geom_vertex.colour = []
                for color_layer in vertex_colors:
                    blender_color = color_layer.data[loop_index].color[:]
                    sims_color = [int(c * 255) for c in blender_color]
                    sims_color.append(0)
                    geom_vertex.colour.append(sims_color)

            geom_vertex.uv = loop_uv
            geom_vertex.tangent = loop_tangent

            # Record the index directly instead of searching geom_vertices later (was O(n^2)).
            match_index = len(geom_vertices)
            geom_vertices.append(geom_vertex)
            candidates.append((geom_vertex, match_index))
        indices[face_index].append(match_index)

    # NOTE(review): stitch_uvs() runs before geom.indices/vertices are replaced below, so it
    # presumably sees the previous geometry; confirm this ordering is intended.
    uv_stitch_data = geom.stitch_uvs()
    geom.indices = indices
    geom.vertices = geom_vertices
    geom.set_body_type(body_type)
    geom.uv_stitches = uv_stitch_data
    geom.bones = vertex_groups
    set_context('OBJECT', mesh_object)


def load_geom(name, geom, blend_vertex_lod_map, armature_rig, material):
    """Create a Blender mesh object named ``name`` from the BodyGeometry ``geom``.

    The object is linked to the current scene and made active. It is skinned to ``armature_rig``
    through an Armature modifier, with one vertex group per ``geom.bones`` entry found on the rig.
    Finally it is rotated +90 degrees about X (``save_geom`` applies the inverse rotation).

    :param name: name of both the new mesh datablock and the new object.
    :param geom: BodyGeometry supplying vertices, triangle indices, bone hashes, UVs and colours.
    :param blend_vertex_lod_map: mapping of morph name -> {vertex id: blend vertex}. Each entry
        becomes a shape key whose value is driven by ``pose.bones["<name>"].morph_value`` on
        ``armature_rig``. May be empty.
    :param armature_rig: armature object; its bone names are FNV32-hashed to resolve ``geom.bones``.
    :param material: optional material appended to the mesh.
    :return: the new mesh object.
    """
    # geom.bones stores FNV32 hashes; map them back to rig bone names.
    bone_name_map = {}
    for bone in armature_rig.data.bones:
        bone_name_map[FNV32.hash(bone.name)] = bone.name

    # Create mesh and add to scene
    mesh_name = name
    mesh = bpy.data.meshes.new(mesh_name)
    mesh_obj = bpy.data.objects.new(mesh_name, mesh)
    if material:
        mesh_obj.data.materials.append(material)
    bpy.context.scene.objects.link(mesh_obj)
    bpy.context.scene.objects.active = mesh_obj

    # One shape key per morph, each driven by the morph_value of the rig pose bone with the same name.
    if blend_vertex_lod_map:
        shape_key_basis = mesh_obj.shape_key_add(from_mix=False, name='Basis')
        for blend_name in blend_vertex_lod_map:
            shape_key_morph = mesh_obj.shape_key_add(from_mix=False, name=blend_name)
            shape_key_morph.relative_key = shape_key_basis
            shape_key_fcurve = shape_key_morph.driver_add('value')
            shape_key_fcurve.driver.type = 'AVERAGE'
            shape_key_var = shape_key_fcurve.driver.variables.new()
            shape_key_var.type = 'SINGLE_PROP'
            shape_key_var.targets[0].id_type = 'OBJECT'
            shape_key_var.targets[0].id = armature_rig
            shape_key_var.targets[0].data_path = 'pose.bones["%s"].morph_value' % blend_name

    mesh_uvs = []

    # Add Armature modifier and vertex groups to hold bone weights.
    mesh_skin = mesh_obj.modifiers.new(type='ARMATURE', name="%s_skin" % mesh_name)
    mesh_skin.use_bone_envelopes = False
    mesh_skin.object = armature_rig
    # Keep vertex_groups index-aligned with geom.bones, because each vertex's blend indices point
    # into geom.bones. A bone missing from the rig gets a None placeholder; dropping it would
    # shift every later bone's weights onto the wrong group.
    vertex_groups = []
    for geom_bone in geom.bones:
        if geom_bone in bone_name_map:
            vertex_groups.append(mesh_obj.vertex_groups.new(bone_name_map[geom_bone]))
        else:
            vertex_groups.append(None)

    # Build the topology with BMesh.
    bm = bmesh.new()
    bm.from_mesh(mesh)
    for vertex in geom.vertices:
        bm.verts.new(vertex.position)
    # Newer Blender releases require the lookup table before indexing bm.verts.
    if hasattr(bm.verts, 'ensure_lookup_table'):
        bm.verts.ensure_lookup_table()

    # Faces that are degenerate or rejected by BMesh; a set keeps later membership tests O(1).
    skipped_faces = set()
    for face_index, face in enumerate(geom.indices):
        if invalid_face(face):
            print('[%s]Face[%04i] %s has duplicate points, skipped' % (mesh_name, face_index, face))
            skipped_faces.add(face_index)
            continue
        f = [bm.verts[face_point] for face_point in face]
        try:
            bm.faces.new([f[0], f[1], f[2]])
        except ValueError as ve:
            skipped_faces.add(face_index)
            print(ve)

    bm.normal_update()
    bm.to_mesh(mesh)
    bm.free()

    # Loop invariants for the per-vertex pass below.
    blend_names = sorted(blend_vertex_lod_map) if blend_vertex_lod_map else []
    key_blocks = {blend_name: mesh.shape_keys.key_blocks[blend_name] for blend_name in blend_names}
    group_count = len(vertex_groups)

    color_channels = 0
    for vertex_index, vertex in enumerate(geom.vertices):
        if vertex.colour:
            color_channels = max(color_channels, len(vertex.colour))
        pos = mesh.vertices[vertex_index].co.copy()
        if blend_names:
            shape_key_basis.data[vertex_index].co = pos
        # Shape key positions are the basis plus the morph's per-vertex delta.
        for blend_name in blend_names:
            morph_blend_vertices = blend_vertex_lod_map[blend_name]
            morph_pos = pos.copy()
            if vertex.id in morph_blend_vertices:
                morph_blend_vertex = morph_blend_vertices[vertex.id]
                if morph_blend_vertex.position:
                    morph_pos += Vector(morph_blend_vertex.position)
            key_blocks[blend_name].data[vertex_index].co = morph_pos

        # Vertex weights, skipping out-of-range indices and bones absent from the rig.
        if vertex.blend_indices:
            for blend_index, blend_bone_index in enumerate(vertex.blend_indices):
                if not 0 <= blend_bone_index < group_count:
                    continue
                blend_vertex_group = vertex_groups[blend_bone_index]
                if blend_vertex_group is None:
                    continue
                try:
                    weight = vertex.blend_weights[blend_index_map[blend_index]] / 255
                    if weight > 0.000:
                        blend_vertex_group.add((vertex_index,), weight, 'REPLACE')
                except Exception as ex:
                    print(ex)

    # Per-loop UVs. Loop offsets only count faces that were actually created.
    normals = []
    uv_face_skipped = 0
    for face_index, face in enumerate(geom.indices):
        if face_index in skipped_faces:
            uv_face_skipped += 1
            continue
        face_normals = []
        for face_point_index, face_point_vertex_index in enumerate([face[0], face[1], face[2]]):
            vertex = geom.vertices[face_point_vertex_index]
            face_normals.append(vertex.normal)
            if vertex.uv:
                for uv_channel_index, uv_coord in enumerate(vertex.uv):
                    if (uv_channel_index + 1) > len(mesh_uvs):
                        mesh_uvs.append(mesh.uv_textures.new(name='uv_%i' % uv_channel_index))
                    mesh.uv_layers[uv_channel_index].data[
                        face_point_index + ((face_index - uv_face_skipped) * 3)].uv = swizzle_uv(uv_coord)
        normals.extend(face_normals)

    # Vertex colour layers (operator acts on the active object, set above). Alpha is dropped.
    for _ in range(color_channels):
        bpy.ops.mesh.vertex_color_add()
    for loop_index, loop in enumerate(mesh.loops):
        vertex = geom.vertices[loop.vertex_index]
        if vertex.colour:
            for color_index in range(len(vertex.colour)):
                mesh.vertex_colors[color_index].data[loop_index].color = Color(
                    [c / 255.0 for c in vertex.colour[color_index][:3]])

    set_context('EDIT', mesh_obj)
    bpy.ops.mesh.select_all(action='SELECT')
    bpy.ops.mesh.faces_shade_smooth()
    set_context('OBJECT', mesh_obj)

    # Custom split normals are currently disabled.
    use_custom_normals = False
    if use_custom_normals and tuple(bpy.app.version[:2]) >= (2, 74):
        mesh.use_auto_smooth = True
        mesh.show_edge_sharp = True
        mesh.normals_split_custom_set(normals)
        remove_doubles = True
        if remove_doubles:
            set_context('EDIT', mesh_obj)
            bpy.ops.mesh.select_all(action='SELECT')
            bpy.ops.mesh.remove_doubles()
        set_context('OBJECT', mesh_obj)

    bpy.ops.transform.rotate(value=math.pi / 2.0, axis=(1, 0, 0))
    bpy.ops.object.transform_apply(rotation=True)
    mesh_obj.select = False
    mesh_obj.active_shape_key_index = 0
    return mesh_obj


def load_slider(armature_rig):
    """Placeholder for slider loading; not implemented yet.

    Walks the MESH-typed children of ``armature_rig`` without doing anything with them and
    returns None.
    """
    mesh_children = (child for child in armature_rig.children if child.type == 'MESH')
    for _mesh_child in mesh_children:
        continue


def load_cas(package, armature_rig, morphs=False, expressions=False):
    """Import every CASPart found in ``package`` onto ``armature_rig``.

    Each CASPart index entry is fetched and passed to ``load_caspart`` in package order. When the
    package contains no CASPart, ``load_caspart`` is called once with ``caspart=None``. The meshes
    returned by ``load_caspart`` are not collected; this function returns None.
    """
    caspart_entries = list(package.find_all_type(CASPart.ID))
    if not caspart_entries:
        load_caspart(None, package, armature_rig, morphs, expressions)
        return
    for entry in caspart_entries:
        load_caspart(entry.fetch(CASPart), package, armature_rig, morphs, expressions)


def load_caspart(caspart, package, armature_rig, morphs=False, expressions=False):
    """Import the meshes of one CAS part from ``package`` onto ``armature_rig``.

    When ``caspart`` is None, the package's first VisualProxy (VPXY) is used instead. Every
    BlendGeometry in the package is loaded as a morph; the CAS part's Fat/Fit/Thin/Pregnant blends
    are also loaded by friendly name when ``morphs`` is set. When ``expressions`` is set, each
    BodyGeometry whose name contains 'Expressions' is loaded as a morph. In that case a driver
    bone is added under a 'b__DRIVERS__' root bone.

    :return: list of mesh objects from ``load_geom``, one per resource of the VPXY's first LOD
        entry (``lod_hi``).
    :raises Sims4StudioException: when ``caspart`` is None and the package has no VisualProxy.
    """
    bgeo = {}
    preset = None
    loaded_morphs = []

    # Loads a BlendGeometry and adds it to bgeo; keys already loaded are skipped. Either key or index
    # must be given. ``name`` is used if provided, else the resource name, else the instance id in hex.
    def load_bgeo(key=None, name=None, index=None):
        if index:
            assert isinstance(index, Package.IndexEntry)
            key = index.key
        print('Loading %s %s' % (key, index))
        if not key.i:
            print('Skipping invalid BlendGeometry %s:  Instance must not be 0.' % key)
            return
        if key in loaded_morphs:
            print('Skipping BlendGeometry %s: Already loaded.' % key)
            return
        try:
            if not index:
                # A BlendData key is resolved to the BlendGeometry it references.
                if key.t == BlendData.ID:
                    blend_data = package.find_key(key).fetch(BlendData)
                    assert isinstance(blend_data, BlendData)
                    key = blend_data.blend_geometry.key
                index = package.find_key(key)
            if not index:
                print('Skipping BlendGeometry %s: Resource not found in package' % key)
                return

            assert isinstance(index, Package.IndexEntry)
            resource = index.fetch(BlendGeometry)
            assert isinstance(resource, BlendGeometry)
            if not name:
                name = resource.resource_name
            if not name:
                # NOTE(review): '%16X' pads with spaces, not zeros; '%016X' was likely intended.
                # Left unchanged because other code may depend on these exact names.
                name = '%16X' % key.i
            bgeo[name] = resource
            loaded_morphs.append(key)

        except Exception as ex:
            # Best effort: a broken morph is reported and skipped rather than aborting the import.
            print('Skipping BlendGeometry %s: Error loading' % key)
            print(ex)

    # Maps blend LOD vertices by their id.
    def map_blend_vertex(blend_lod):
        bvmap = {}
        for v in blend_lod.vertices:
            bvmap[v.id] = v
        return bvmap

    # Adds a bone to the skeleton (if missing) and optionally parents it; leaves the rig in POSE mode.
    # NOTE(review): the returned EditBone is obtained in EDIT mode and is likely invalid after the
    # switch to POSE; callers here ignore it.
    def create_armature_bone(bone_name, parent_bone=None, min_bone=.001):
        set_context('EDIT', armature_rig)
        armature_bone = find_bone(armature_rig.data.edit_bones, bone_name)
        if not armature_bone:
            armature_bone = armature_rig.data.edit_bones.new(bone_name)
            armature_bone.use_connect = False
            armature_bone.tail = [0, min_bone, 0]
        if parent_bone:
            armature_bone.parent = armature_rig.data.edit_bones[parent_bone]
        set_context('POSE', armature_rig)
        return armature_bone

    if caspart:
        assert isinstance(caspart, CASPart)
        # Only dereference caspart once it is known to be set; load_cas passes None when the
        # package has no CASPart.
        print('Loading CASPart %s...' % caspart.resource_name)
        print('CASP found...')
        # Prefer a PackedPreset sharing the CASPart's group/instance, else the part's first preset.
        preset = package.find_key(ResourceKey(t=PackedPreset.ID, g=caspart.key.g, i=caspart.key.i))
        if preset:
            preset = preset.fetch(PackedPreset)
        elif any(caspart.presets):
            preset = caspart.presets[0]
        part_name = caspart.part_name
        vpxy = package.get_resource(key=caspart.sources[0].key, wrapper=VisualProxy)

        # Load standard morphs if specified with user friendly name
        if morphs:
            if caspart.blend_fat.key.i:
                load_bgeo(name='Fat', key=caspart.blend_fat.key)
            if caspart.blend_fit.key.i:
                load_bgeo(name='Fit', key=caspart.blend_fit.key)
            if caspart.blend_thin.key.i:
                load_bgeo(name='Thin', key=caspart.blend_thin.key)
            if caspart.blend_special.key.i:
                load_bgeo(name='Pregnant', key=caspart.blend_special.key)

    else:
        print('No CASP found, defaulting to first VPXY')
        vpxy_index = first(package.find_all_type(VisualProxy.ID))
        if not vpxy_index:
            raise Sims4StudioException('Mesh Import Error',
                                       'The package contains neither a CASPart nor a VisualProxy (VPXY).')
        vpxy = vpxy_index.fetch(VisualProxy)
        assert isinstance(vpxy, VisualProxy)
        part_name = vpxy.resource_name

    print('Loading morphs...')
    for bgeo_index in package.find_all_type(BlendGeometry.ID):
        try:
            load_bgeo(index=bgeo_index)
        except Exception as ex:
            print("Unable to load morph %s" % bgeo_index)
            print(ex)
    if not preset:
        preset = Preset()
    ml = MaterialLoader(package, preset)
    lod_hi = first(vpxy.entries, lambda e: e.TYPE == VisualProxy.LodEntry.TYPE)
    assert isinstance(lod_hi, VisualProxy.LodEntry)

    # Arrange morph data for processing: morph name -> {vertex id: blend vertex} for lod_hi.
    blend_vertex_lod_map = {}
    for blend_name in bgeo:
        cur_bgeo = bgeo[blend_name]
        # NOTE(review): each blend overwrites the previous one, so only the last blend of a
        # BlendGeometry is kept; confirm whether multiple blends should be merged.
        for blend in cur_bgeo.blends:
            blend_vertex_lod_map[blend_name] = map_blend_vertex(blend.lods[lod_hi.index])

    # Face morphs for animal meshes: every BodyGeometry whose name contains 'Expressions' becomes a
    # morph driven by a bone parented to the driver root.
    if expressions:
        driver_root = 'b__DRIVERS__'
        create_armature_bone(driver_root)
        print('creating root: %s' % driver_root)
        for index in package.find_all_type(BodyGeometry.ID):
            assert isinstance(index, Package.IndexEntry)
            geom = index.fetch(BodyGeometry)
            assert isinstance(geom, BodyGeometry)
            # resource_name may be unset; treat that as "not an expression" instead of crashing.
            if not geom.resource_name or 'Expressions' not in geom.resource_name:
                continue
            # NOTE(review): fixed slice assumes a 17-character prefix and a 2-character suffix.
            blend_name = geom.resource_name[17:-2]
            create_armature_bone(blend_name, parent_bone=driver_root)
            blend_vertex_lod_map[blend_name] = map_blend_vertex(geom)

    meshes = []
    lod_geoms = (package.find_key(item.key).fetch(BodyGeometry) for item in lod_hi.resources)
    for lod_sub_index, geom in enumerate(lod_geoms):
        material = ml.generate('%s_%i' % (part_name, lod_sub_index), geom.material)
        meshes.append(
            load_geom(part_name + '_' + str(lod_sub_index), geom, blend_vertex_lod_map, armature_rig, material))
    return meshes




